using namespace metal;

// Per-draw transforms, bound to vertex buffer index 1 in vertex_main.
// NOTE(review): field order and sizes presumably mirror a host-side struct;
// keep both in sync. No model matrix is applied in vertex_main, so vertex
// positions are presumably already in world space — confirm.
struct Matrices
{
    float4x4 viewProjectionMatrix; // vertex position -> clip space (FragmentIn::position)
    float4x4 viewMatrix;           // vertex position -> positionCS used for lighting
    float4x4 normalMatrix;         // applied to float4(normal, 0), i.e. without translation
};

// The single light used for shading; bound to fragment buffer index 1 in
// fragment_main and consumed by ComputePixelColor.
// NOTE(review): layout presumably mirrors a host-side struct; keep in sync.
struct LightAttributes
{
    float4 position;          // w == 0: directional light, xyz = direction towards the light;
                              // otherwise xyz = light position in the same space as positionCS
    float4 color;             // rgb scales both the diffuse and the specular term
    float4 spotDirection;     // xyz = spot axis, compared via -dot(spotDirection, lightDir)
    float diffuseIntensity;
    float specularIntensity;
    float spotCutoff;         // threshold on that dot product (a cosine if the axis is unit
                              // length); <= 0 disables the spot test; positional lights only
    float spotExponent;       // falloff exponent applied inside the cone
};

// Surface parameters, bound to fragment buffer index 0 in fragment_main and
// consumed by ComputePixelColor.
// NOTE(review): layout presumably mirrors a host-side struct; keep in sync.
struct Material
{
    float4 ambientColor;   // rgb: base term, applied whether or not lighting is on
    float4 diffuseColor;
    float4 specularColor;  // rgb tints the accumulated specular term
    float4 emissiveColor;  // rgb added only when lighting is actually evaluated
    float4 highlightColor; // if a > 0, final rgb is blended towards rgb by factor a
    float shininess;       // <= 0 disables specular; specular exponent is (1 - shininess) * 128
    float transparency;    // output alpha = (1 - transparency) * base-color alpha
    uint32_t lightingMode; // 0 = unlit, 2 = two-sided lighting, other non-zero = one-sided
    float padding;         // presumably pads the struct to a multiple of 16 bytes
};

// Pipeline specialisation constants supplied when the shader functions are built.
// Each *_defined flag enables the matching optional VertexIn attribute and the
// code paths in vertex_main / fragment_main that read it.
constant bool normal_defined [[function_constant(0)]];
constant bool texCoord_defined [[function_constant(1)]];
constant bool color_defined [[function_constant(2)]];
// NOTE(review): gamma is referenced only by commented-out code in the visible
// source; it may be vestigial — confirm the host no longer depends on index 3
// before removing it.
constant float gamma [[function_constant(3)]];

// Per-vertex attributes fetched through [[stage_in]]. normal, texCoord and color
// exist only when the matching function constant is true.
// NOTE(review): texCoord and color share attribute index 2, so they must be
// mutually exclusive (vertex_main forwards at most one of them); enabling both
// constants presumably fails to compile — confirm the host never does.
struct VertexIn
{
    float3 position [[attribute(0)]];
    float3 normal [[attribute(1), function_constant(normal_defined)]];
    float2 texCoord [[attribute(2), function_constant(texCoord_defined)]];
    float3 color [[attribute(2), function_constant(color_defined)]];
};

// Vertex-to-fragment interface: written by vertex_main, read by fragment_main.
struct FragmentIn
{
    float4 position [[position]]; // clip-space position (viewProjectionMatrix * position)
    float3 normal;                // normal after normalMatrix; not normalized here
    float2 texCoord;
    float3 color;                 // per-vertex color, used when there are no texCoords
    float3 positionCS;            // position after viewMatrix, used for lighting
	float pointSize [[point_size]]; // rasterized point size; vertex_main writes 2.0
};

// Decodes one sRGB-encoded channel value to linear light using the standard
// piecewise sRGB curve: a straight line below the knee, a 2.4 power above it.
float srgbToLinear( float s )
{
	const bool belowKnee = s <= 0.0404482362771082f;
	return belowKnee ? s / 12.92f : pow( (s + 0.055f) / 1.055f, 2.4f );
}

// Decodes an sRGB-encoded color to linear light one channel at a time.
// Alpha is not color-encoded, so it is passed through untouched.
float4 srgbToLinear( float4 color )
{
	float4 decoded = color;
	decoded.r = srgbToLinear( color.r );
	decoded.g = srgbToLinear( color.g );
	decoded.b = srgbToLinear( color.b );
	return decoded;
}

// Encodes one linear-light channel value with the standard piecewise sRGB
// curve: a straight line below the knee, a 1/2.4 power segment above it.
// All literals are single precision (the original mixed in the double literal
// 1.055 and an integer 1), matching srgbToLinear.
float linearToSRGB( float l )
{
	if( l <= 0.00313066844250063f )
		return l * 12.92f;
	else
		return 1.055f * pow( l, 1.0f / 2.4f ) - 0.055f;
}

// Encodes a linear-light color to sRGB one channel at a time.
// Alpha is not color-encoded, so it is passed through untouched.
float4 linearToSRGB( float4 color )
{
	float4 encoded = color;
	encoded.r = linearToSRGB( color.r );
	encoded.g = linearToSRGB( color.g );
	encoded.b = linearToSRGB( color.b );
	return encoded;
}

// Shades one fragment with a single light (ambient + emissive + diffuse +
// Phong specular), then applies the base color, transparency and the optional
// highlight tint.
//
// normal      Surface normal; need not be unit length. An all-zero normal skips
//             lighting entirely, leaving ambient * texColor.
// positionCS  Fragment position in the space produced by Matrices::viewMatrix;
//             the eye is taken to be at the origin of that space
//             (eyeDir = -normalize(positionCS)).
// material    lightingMode 0 = unlit, 2 = two-sided lighting, any other
//             non-zero value = one-sided lighting.
// light       position.w == 0 selects a directional light (xyz = direction
//             towards the light); otherwise a point light, optionally a spot.
// texColor    Multiplies the surface color; its alpha scales the opacity.
// Returns     rgb color and alpha = (1 - material.transparency) * texColor.a.
//             No color-space conversion is performed here.
float4 ComputePixelColor( float3 normal, float3 positionCS, Material material, LightAttributes light, float4 texColor = float4(1) )
{
    //Global ambient lighting
    float3 surfaceColor = material.ambientColor.rgb;
    float3 lightSpecularColor(0);

    const bool lit = material.lightingMode != 0;
    if( lit && (normal.x != 0 || normal.y != 0 || normal.z != 0) )
    {
        //Material emissive
        surfaceColor += material.emissiveColor.rgb;
        float3 N = normalize( normal );

        // Unit vector from the surface towards the light.
        float3 lightDir;
        if( light.position.w == 0 )
            lightDir = normalize( light.position.xyz );
        else
            lightDir = normalize( light.position.xyz - positionCS );

        //Diffuse (Lambert) term; negative when the light is behind the surface.
        float diffuseAtten = dot( lightDir, N );

        //Spot cone: only for positional lights with a positive cutoff. The
        //cutoff is compared against the dot product of the spot axis and the
        //light-to-surface direction; outside the cone the light contributes nothing.
        float spotAtten = 1;
        if( light.position.w != 0 && light.spotCutoff > 0 )
        {
            spotAtten = -dot( light.spotDirection.xyz, lightDir );
            spotAtten = spotAtten > light.spotCutoff ? pow( spotAtten, light.spotExponent ) : 0.0f;
            diffuseAtten *= spotAtten;
        }

        //Two-sided lighting: a surface lit from behind is treated as front lit.
        if( material.lightingMode == 2 && diffuseAtten < 0 )
            diffuseAtten = -diffuseAtten;
        surfaceColor += max( diffuseAtten, 0.0f ) * light.diffuseIntensity * light.color.rgb * material.diffuseColor.rgb;

        //Specular lighting. Only add a highlight when the light actually reaches
        //this side of the surface (diffuseAtten > 0). Without this gate the
        //reflection vector can still line up with the eye for a light just
        //behind the surface, and a highlight leaks through it.
        if( material.shininess > 0 && diffuseAtten > 0 )
        {
            float3 eyeDir = -normalize( positionCS );
            //Exponent is (1 - shininess) * 128: small shininess gives a tight
            //highlight, shininess near 1 a very broad one.
            float specAtten = pow( max( dot( -reflect( lightDir, N ), eyeDir ), 0.0f ), (1.0f - material.shininess) * 128 );
            lightSpecularColor += specAtten * spotAtten * light.specularIntensity * light.color.rgb;
        }
        surfaceColor = clamp( surfaceColor, 0.0f, 1.0f );
        lightSpecularColor = clamp( lightSpecularColor, 0.0f, 1.0f );
    }

    float alpha = (1 - material.transparency) * texColor.a;
    surfaceColor *= texColor.rgb;

    float4 fragmentColor;
    if( lit )
        //Material-tinted specular added on top of the surface color, capped at 1.
        fragmentColor = float4( min( surfaceColor + lightSpecularColor * material.specularColor.rgb, float3(1.0f) ), alpha );
    else
        //Unlit: surface color with transparency only.
        fragmentColor = float4( surfaceColor, alpha );

    //Optional highlight: blend towards highlightColor.rgb by highlightColor.a.
    if( material.highlightColor.a > 0.0 )
        fragmentColor.rgb = material.highlightColor.a * material.highlightColor.rgb + (1.0 - material.highlightColor.a) * fragmentColor.rgb;

    return fragmentColor;
}

//FORWARD

// Vertex stage: projects the vertex to clip space, computes the position used
// for lighting, and forwards the optional attributes enabled by the function
// constants. Matrices are read from vertex buffer index 1.
// NOTE(review): the marker comments inside this function and the ARGUMENTS
// marker in the signature appear to be substitution points for a shader
// generator; keep them and the local names IN, OUT and matrices unchanged.
vertex FragmentIn vertex_main(VertexIn IN [[stage_in]], constant Matrices& matrices [[buffer(1)]] /*ARGUMENTS*/)
{
//EARLY
    FragmentIn OUT;
    OUT.position = matrices.viewProjectionMatrix * float4(IN.position, 1);
	OUT.pointSize = 2.0f;
    // Give every optional varying a defined value, so outputs disabled by the
    // function constants are never returned with uninitialized contents.
    OUT.normal = float3(0);
    OUT.texCoord = float2(0);
    OUT.color = float3(0);
    // w = 0 so only the linear part of normalMatrix applies (no translation).
    if(normal_defined)
        OUT.normal = (matrices.normalMatrix * float4(IN.normal, 0)).xyz;
    OUT.positionCS = (matrices.viewMatrix * float4(IN.position, 1)).xyz;
    // texCoord and color share vertex attribute 2, so at most one is forwarded;
    // texCoord takes precedence.
    if(texCoord_defined)
        OUT.texCoord = IN.texCoord;
	else if(color_defined)
		OUT.color = IN.color;
//RETURN
    return OUT;
}

// Fragment stage: picks the base color (texture sample, per-vertex color or
// white, according to the function constants), shades it with
// ComputePixelColor, and returns the result passed through srgbToLinear.
// Material is read from fragment buffer index 0, the light from index 1.
// NOTE(review): the marker comments inside this function and the ARGUMENTS
// marker in the signature appear to be substitution points for a shader
// generator; injected code may rely on the local names IN, OUT, texColor, normal.
fragment float4 fragment_main(FragmentIn IN [[stage_in]], texture2d<float> Texture [[texture(0)]], sampler Sampler [[sampler(0)]], constant Material& material [[buffer(0)]], constant LightAttributes& light [[buffer(1)]] /*ARGUMENTS*/)
{
//EARLY
    float4 texColor;
    // Lighting works on sRGB-encoded values: the texel is encoded with
    // linearToSRGB here and the lit result is decoded with srgbToLinear below.
    // NOTE(review): presumably the texture and render target use sRGB pixel
    // formats (hardware decode on sample / encode on store) — confirm on host.
    if(texCoord_defined)
        texColor = linearToSRGB( Texture.sample( Sampler, IN.texCoord) );
	else if(color_defined)
	{
		//Vertex colors are used as-is (presumably already sRGB-encoded — TODO
		//confirm); the srgbToLinear conversion below is disabled.
//		texColor = srgbToLinear( float4( IN.color, 1.0 ) );
		texColor = float4( IN.color, 1.0 );
	}
    else
        texColor = float4(1);
    // A zero normal makes ComputePixelColor skip lighting (ambient only).
    float3 normal(0);
    if(normal_defined)
        normal = IN.normal;
	float4 OUT = ComputePixelColor( normal, IN.positionCS, material, light, texColor );
//RETURN
    return srgbToLinear( OUT );
}
